import argparse
import os
from pathlib import Path

import mmcv
from mmcv import Config, DictAction

from mmdet.core.utils import mask2ndarray
from mmdet.core.visualization import imshow_det_bboxes
from mmdet.datasets import build_dataset, get_loading_pipeline


def parse_args():
    """Define the command-line interface of the dataset browser and parse it.

    Returns:
        argparse.Namespace: parsed options with the attributes ``config``,
        ``skip_type``, ``output_dir``, ``not_show``, ``show_interval`` and
        ``cfg_options``.
    """
    cli = argparse.ArgumentParser(description="Browse a dataset")

    # Positional: the config whose dataset settings are browsed.
    cli.add_argument("config", help="train config file path")

    # Pipeline transform ``type`` names removed from the train pipeline.
    default_skips = ["DefaultFormatBundle", "Normalize", "Collect"]
    cli.add_argument(
        "--skip-type",
        nargs="+",
        type=str,
        default=default_skips,
        help="skip some useless pipeline",
    )
    cli.add_argument(
        "--output-dir",
        type=str,
        default=None,
        help="If there is no display interface, you can save it",
    )
    cli.add_argument("--not-show", action="store_true", default=False)
    cli.add_argument(
        "--show-interval",
        type=float,
        default=2,
        help="the interval of show (s)",
    )

    cfg_options_help = (
        "override some settings in the used config, the key-value pair "
        "in xxx=yyy format will be merged into config file. If the value to "
        'be overwritten is a list, it should be like key="[a,b]" or key=a,b '
        'It also allows nested list/tuple values, e.g. key="[(a,b),(c,d)]" '
        "Note that the quotation marks are necessary and that no white space "
        "is allowed."
    )
    cli.add_argument(
        "--cfg-options",
        nargs="+",
        action=DictAction,
        help=cfg_options_help,
    )

    return cli.parse_args()


def retrieve_data_cfg(config_path, skip_type, cfg_options):
    """Load a config file and drop unwanted steps from its training pipeline.

    Args:
        config_path (str): Path of the config file to load.
        skip_type (list[str]): Transform ``type`` names to remove from
            ``cfg.data.train.pipeline``.
        cfg_options (dict | None): Overrides merged into the loaded config;
            ignored when ``None``.

    Returns:
        Config: the loaded config. ``cfg.data.train["pipeline"]`` is replaced
        in place by the filtered list of steps.
    """
    cfg = Config.fromfile(config_path)
    if cfg_options is not None:
        cfg.merge_from_dict(cfg_options)

    # Import any user-specified modules (``custom_imports``) before anything
    # is built from this config.
    custom_imports = cfg.get("custom_imports", None)
    if custom_imports:
        from mmcv.utils import import_modules_from_strings

        import_modules_from_strings(**custom_imports)

    train_cfg = cfg.data.train
    kept_steps = []
    for step in train_cfg.pipeline:
        if step["type"] in skip_type:
            continue
        kept_steps.append(step)
    train_cfg["pipeline"] = kept_steps

    return cfg


def main():
    """Draw the ground-truth annotations of every dataset sample.

    Each sample is displayed in a window unless ``--not-show`` is given. It is
    also written to ``--output-dir``, under the base name of its source image,
    when that option is set.
    """
    args = parse_args()
    cfg = retrieve_data_cfg(args.config, args.skip_type, args.cfg_options)
    # Browse the dataset described by ``cfg.data.test``, but give it loading
    # steps derived from the (already filtered) train pipeline. The loop below
    # reads gt_bboxes/gt_labels, so this presumably keeps annotation loading.
    cfg.data.test.pipeline = get_loading_pipeline(cfg.data.train.pipeline)

    dataset = build_dataset(cfg.data.test)
    progress_bar = mmcv.ProgressBar(len(dataset))

    for item in dataset:
        # Save path is only set when an output directory was requested.
        filename = None
        if args.output_dir is not None:
            filename = os.path.join(args.output_dir, Path(item["filename"]).name)

        gt_masks = item.get("gt_masks", None)
        if gt_masks is not None:
            # Convert mask structures to plain arrays for drawing.
            gt_masks = mask2ndarray(gt_masks)

        imshow_det_bboxes(
            item["img"],
            item["gt_bboxes"],
            item["gt_labels"],
            gt_masks,
            class_names=dataset.CLASSES,
            show=not args.not_show,
            wait_time=args.show_interval,
            out_file=filename,
            bbox_color=(255, 102, 61),
            text_color=(255, 102, 61),
        )

        progress_bar.update()


# Run the browser only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
